package CPU.rv64_1stage
import chisel3._
import chisel3.util._

object Decode {
  /** Instantiates a [[Decode]] module, connects `inst` to its input and hands back its outputs.
    *
    * @param inst 32-bit instruction word to decode
    * @return (rs1_addr, rs2_addr, rd_addr, rs1_op, rs2_op, rd_en, opcode, fuType);
    *         the module's `instType` output is not part of the tuple
    */
  def apply(inst: UInt): (UInt, UInt, UInt, UInt, UInt, Bool, UInt, UInt) = {
    val ports = Module(new Decode).io
    ports.inst := inst

    val regAddrs = (ports.rs1_addr, ports.rs2_addr, ports.rd_addr)
    (regAddrs._1, regAddrs._2, regAddrs._3,
      ports.rs1_op, ports.rs2_op, ports.rd_en,
      ports.opcode, ports.fuType)
  }
}

class Decode extends Module with HasInstrType {
  // Ports: one raw 32-bit instruction in, decoded control fields out.
  val io = IO(new Bundle {
    val inst     = Input(UInt(32.W))
    val rs1_addr = Output(UInt(5.W))  // inst[19:15]
    val rs2_addr = Output(UInt(5.W))  // inst[24:20]
    val rd_addr  = Output(UInt(5.W))  // inst[11:7]
    val rs1_op   = Output(UInt(2.W))  // operand-1 source select (encoding below)
    val rs2_op   = Output(UInt(3.W))  // operand-2 source select (encoding below)
    val rd_en    = Output(Bool())     // register-file write enable
    val opcode   = Output(UInt(7.W))  // fuOpType column of the decode table
    val fuType   = Output(UInt(3.W))  // fuType column of the decode table
    val instType = Output(UInt(4.W))  // instrType column of the decode table
  })

  val inst = io.inst

  // A single table lookup yields (instruction format, functional unit, unit op code).
  val Decode_signals = ListLookup(inst, RVIInstr.DecodeDefault, RVIInstr.table)
  val List(instrType, fuType, fuOpType) = Decode_signals

  /** True when the decoded format equals any of `formats`. */
  private def isOneOf(formats: UInt*): Bool = formats.map(instrType === _).reduce(_ || _)

  // Register specifiers are sliced from fixed bit positions of the instruction word.
  io.rs1_addr := inst(19, 15)
  io.rs2_addr := inst(24, 20)
  io.rd_addr  := inst(11, 7)

  io.instType := instrType
  io.fuType   := fuType
  io.opcode   := fuOpType
  io.rd_en    := isrfWen(instrType)

  // Operand-1 select: 1 = rs1, 2 = pc, 0 = default.
  // A U-type with inst(5) clear takes pc; a U-type with inst(5) set falls through to 0
  // (presumably separating AUIPC from LUI). The format tests are mutually exclusive,
  // so MuxCase priority order is irrelevant here.
  io.rs1_op := MuxCase(0.U, Seq(
    isOneOf(InstrR, InstrI, InstrB, InstrS)                 -> 1.U,
    (isOneOf(InstrJ) || (instrType === InstrU && !inst(5))) -> 2.U
  ))

  // Operand-2 select: 1 = rs2, 2 = imm_i, 3 = imm_u, 4 = constant +4, 5 = S-type, 0 = default.
  io.rs2_op := MuxCase(0.U, Seq(
    isOneOf(InstrR, InstrB) -> 1.U,
    isOneOf(InstrI)         -> 2.U,
    isOneOf(InstrU)         -> 3.U,
    isOneOf(InstrJ)         -> 4.U,
    isOneOf(InstrS)         -> 5.U
  ))
}

trait HasInstrType {
  // 4-bit instruction-format codes. Bit 2 marks formats that write the register file.

  // Formats with bit 2 clear (no register write-back).
  def InstrN: UInt  = "b0000".U
  def InstrB: UInt  = "b0001".U
  def InstrS: UInt  = "b0010".U

  // Formats with bit 2 set (register write-back).
  def InstrI: UInt  = "b0100".U
  def InstrR: UInt  = "b0101".U
  def InstrU: UInt  = "b0110".U
  def InstrJ: UInt  = "b0111".U
  def InstrA: UInt  = "b1110".U
  def InstrSA: UInt = "b1111".U // Atom Inst: SC

  /** Register-file write enable for a format code: simply its bit 2. */
  def isrfWen(instrType: UInt): Bool = {
    val wenBit = 2
    instrType(wenBit)
  }
}

object FuType {
  /** Number of distinct functional-unit encodings defined below. */
  def num: Int = 5

  // 3-bit functional-unit selectors.
  def alu: UInt = "b000".U
  def lsu: UInt = "b001".U
  def mdu: UInt = "b010".U
  def csr: UInt = "b011".U
  def mou: UInt = "b100".U

  /** Branches share the ALU encoding. */
  def bru: UInt = alu
}

object ALUOpType {
  // ---- 64-bit integer ops --------------------------------------------------------
  def add: UInt  = "b1000000".U
  def sll: UInt  = "b0000001".U
  def slt: UInt  = "b0000010".U
  def sltu: UInt = "b0000011".U
  def xor: UInt  = "b0000100".U
  def srl: UInt  = "b0000101".U
  def or: UInt   = "b0000110".U
  def and: UInt  = "b0000111".U
  def sub: UInt  = "b0001000".U
  def sra: UInt  = "b0001101".U

  // ---- 32-bit ("W") variants: same low bits, plus bit 5 -----------------------------
  def addw: UInt = "b1100000".U
  def subw: UInt = "b0101000".U
  def sllw: UInt = "b0100001".U
  def srlw: UInt = "b0100101".U
  def sraw: UInt = "b0101101".U

  /** Bit 5 marks a word (32-bit) operation. */
  def isWordOp(func: UInt): Bool = func(5)

  // ---- jumps and branches (bit 4 set) ----------------------------------------------
  def jal: UInt  = "b1011000".U
  def jalr: UInt = "b1011010".U
  def beq: UInt  = "b0010000".U
  def bne: UInt  = "b0010001".U
  def blt: UInt  = "b0010100".U
  def bge: UInt  = "b0010101".U
  def bltu: UInt = "b0010110".U
  def bgeu: UInt = "b0010111".U

  // for RAS
  def call: UInt = "b1011100".U
  def ret: UInt  = "b1011110".U

  // ---- field predicates ------------------------------------------------------------
  /** Bit 6: the op uses the adder (add, addw, jal, jalr, call, ret). */
  def isAdd(func: UInt): Bool = func(6)
  /** Bit 5 (shared with isWordOp). */
  def pcPlus2(func: UInt): Bool = func(5)
  /** Bit 4: branch-unit op. */
  def isBru(func: UInt): Bool = func(4)
  /** Bit 3 clear: conditional branch. */
  def isBranch(func: UInt): Bool = !func(3)
  /** Branch-unit op that is not a conditional branch. */
  def isJump(func: UInt): Bool = isBru(func) && !isBranch(func)
  /** Bits 2..1: comparison kind of a conditional branch. */
  def getBranchType(func: UInt): UInt = func(2, 1)
  /** Bit 0: invert the comparison (bne, bge, bgeu). */
  def isBranchInvert(func: UInt): Bool = func(0)
}

object LSUOpType { //TODO: refactor LSU fuop
  // ---- loads: bits 5 and 3 clear ---------------------------------------------------
  def lb: UInt   = "b0000000".U
  def lh: UInt   = "b0000001".U
  def lw: UInt   = "b0000010".U
  def ld: UInt   = "b0000011".U
  def lbu: UInt  = "b0000100".U
  def lhu: UInt  = "b0000101".U
  def lwu: UInt  = "b0000110".U

  // ---- stores: bit 3 set -----------------------------------------------------------
  def sb: UInt   = "b0001000".U
  def sh: UInt   = "b0001001".U
  def sw: UInt   = "b0001010".U
  def sd: UInt   = "b0001011".U

  // ---- atomics: bit 5 set (amoadd additionally sets bit 6, see isAdd) -------------
  def lr: UInt      = "b0100000".U
  def sc: UInt      = "b0100001".U
  def amoswap: UInt = "b0100010".U
  def amoadd: UInt  = "b1100011".U
  def amoxor: UInt  = "b0100100".U
  def amoand: UInt  = "b0100101".U
  def amoor: UInt   = "b0100110".U
  def amomin: UInt  = "b0110111".U
  def amomax: UInt  = "b0110000".U
  def amominu: UInt = "b0110001".U
  def amomaxu: UInt = "b0110010".U

  // ---- field predicates ------------------------------------------------------------
  def isAdd(func: UInt): Bool = func(6)
  def isAtom(func: UInt): Bool = func(5)
  def isStore(func: UInt): Bool = func(3)
  /** Neither a store nor an atomic. */
  def isLoad(func: UInt): Bool = !isStore(func) && !isAtom(func)
  def isLR(func: UInt): Bool = func === lr
  def isSC(func: UInt): Bool = func === sc
  /** Atomic read-modify-write (any atomic other than LR/SC). */
  def isAMO(func: UInt): Bool = isAtom(func) && !isLR(func) && !isSC(func)

  def needMemRead(func: UInt): Bool = isLoad(func) || isAMO(func) || isLR(func)
  def needMemWrite(func: UInt): Bool = isStore(func) || isAMO(func) || isSC(func)

  // Atomic access-width codes (presumably compared with the instruction's funct3 - confirm).
  // These need the "b" radix prefix: Chisel reads the first character of a string
  // literal as its base, so the former "010"/"011" raised an "Invalid base 0" error.
  def atomW: UInt = "b010".U
  def atomD: UInt = "b011".U
}

object CSROpType {
  // 3-bit CSR op codes. Each "*i" form is its register form with bit 2 set.
  def jmp: UInt  = "b000".U

  // register-operand forms
  def wrt: UInt  = "b001".U
  def set: UInt  = "b010".U
  def clr: UInt  = "b011".U

  // immediate-operand forms (bit 2 set)
  def wrti: UInt = "b101".U
  def seti: UInt = "b110".U
  def clri: UInt = "b111".U
}